#' Generate binary data (excluding outcome)
#'
#' Generates binary data (excluding outcome), either uniformly at random (iid
#' Bernoulli(.5)) or such that observations will have pre-specified
#' probabilities of satisfying each rule.
#'
#' Every variable is first drawn iid Bernoulli(.5). Then, for each rule in
#' \code{A} in turn, one uniform number is drawn per observation: rows that
#' satisfy the rule but drew a number >= p have every variable of the rule
#' flipped to the opposite value, and rows that do not satisfy the rule but
#' drew a number < p are overwritten so that they satisfy it. Later rules can
#' overwrite variables that were set by earlier rules.
#'
#' @param varNames vector of variable names
#' @param numCases number of observations
#' @param A a rule set formatted as a list of rules, which are in turn each a
#'          vector of conditions. If left NULL, data is generated iid
#'          Bernoulli(.5)
#' @param p_A vector of probabilities of satisfying each rule in A (one entry
#'            per rule, in the same order as A)
#' @return data frame of binary data with one column per entry of varNames
genData <- function(varNames, numCases, A=NULL, p_A = NULL){
  nrows <- numCases
  ncols <- length(varNames)
  # Baseline: every variable is an independent fair coin flip
  data <- data.frame(matrix(rbinom(n=nrows*ncols, size=1, prob=.5), nrows, ncols))
  colnames(data) <- varNames
  if (is.null(A)) {
    return(data)
  }
  if (length(p_A) < length(A)) {
    stop("p_A must contain one probability per rule in A", call. = FALSE)
  }
  for (i in seq_along(A)) {
    rule <- A[[i]]
    p <- p_A[i]
    rand <- runif(numCases)
    rule_vars <- getNames(rule)
    rule_vals <- getValues(rule)
    cols <- data[rule_vars]
    # Rows that satisfy the rule before any adjustment for this rule
    satisfies <- rep(TRUE, nrows)
    for (k in seq_along(rule_vars)) {
      satisfies <- satisfies & (cols[[k]] == rule_vals[k])
    }
    break_rule <- satisfies & (rand >= p)   # satisfied, but should not
    force_rule <- !satisfies & (rand < p)   # not satisfied, but should
    # Assign column by column: `cols[rows, ] <- vector` would recycle the
    # vector down the columns and scramble rules with mixed 0/1 values.
    for (k in seq_along(rule_vars)) {
      cols[[k]][break_rule] <- abs(rule_vals[k] - 1)
      cols[[k]][force_rule] <- rule_vals[k]
    }
    data[rule_vars] <- cols
  }
  data
}

#' Generate outcome
#'
#' Generates outcomes based on a rule set and the probability that the outcome
#' is positive conditional on whether an observation satisfies each rule.
#'
#' A single uniform draw per observation is shared across all rules. For rule
#' i, an observation becomes positive if it satisfies the rule and its draw is
#' <= p_pos, or if it does not satisfy the rule and its draw is <= p_neg. An
#' observation that has become positive stays positive for all later rules.
#'
#' @param X data frame with binary data for independent variables
#' @param A true rule set with which to generate data (list of rules, each a
#'          vector of conditions)
#' @param p_pos probability of y=1 if x satisfies the rule (default 1),
#'              OR vector of probabilities (p_pos[i] corresponds to A[[i]])
#' @param p_neg probability of y=1 if x does NOT satisfy the rule (default 0),
#'              OR vector of probabilities (p_neg[i] corresponds to A[[i]])
#' @return a numeric 0/1 vector of outcomes, one per row of X. An empty rule
#'         set yields all zeros.
getY <- function(X, A, p_pos=1, p_neg = 0){
  n <- nrow(X)
  Y <- rep(0, n)
  rand <- runif(n)  # one draw per observation, reused for every rule
  pp <- p_pos
  pn <- p_neg
  for (i in seq_along(A)) {
    rule <- A[[i]]
    rule_vars <- getNames(rule)
    rule_vals <- getValues(rule)
    # Cases that satisfy the ith rule
    satisfied <- apply(X[rule_vars], 1, function(x) all(x == rule_vals))
    if (length(p_pos) > 1) {
      pp <- p_pos[i]
    }
    if (length(p_neg) > 1) {
      pn <- p_neg[i]
    }
    positive <- (satisfied & rand <= pp) | (!satisfied & rand <= pn)
    # 1 if made positive by the ith rule or by any previous rule
    Y <- as.numeric(Y == 1 | positive)
  }
  Y
}


#' Get predicted outcomes
#'
#' Get the outcomes that a rule set would predict, i.e. Y=1 if and only if the
#' observation satisfies at least one rule in A
#'
#' @param data data frame of independent variables without outcome
#' @param A rule set
#' @return predicted outcomes for each observation
getYhat <- function(data, A){
  # Delegates to getY with its default probabilities spelled out: with
  # p_pos = 1 and p_neg = 0, getY returns 1 iff the observation is in A.
  getY(X = data, A = A, p_pos = 1, p_neg = 0)
}


#' Format QCA solutions
#'
#' Format the output of QCA solutions to be similar to those outputted by BRS.
#' The rule is lower-cased and split on "*"; surrounding whitespace is removed
#' from every condition, and a condition with a leading tilde (negation) has
#' the tilde dropped and "_neg" appended.
#'
#' @param rule produced by QCA (string), e.g. "~a * b"
#' @return reformatted rule (character vector), e.g. c("a_neg", "b")
reformat <- function(rule){
  conds <- trimws(tolower(strsplit(rule, "*", fixed = TRUE)[[1]]))
  # Negation is marked by a leading tilde, whatever the length of the name
  negated <- startsWith(conds, "~")
  if (any(negated)) {
    conds[negated] <- paste0(trimws(substring(conds[negated], 2)), "_neg")
  }
  conds
}


#' Get variable names for a rule
#'
#' Get the names of variables as they appear in the data corresponding to the
#' conditions in a rule. Each condition is split on "_", every "neg" token is
#' removed, and the remaining tokens are joined with "_" again.
#'
#' @param rule a rule, formatted as a vector of conditions
#' @return character vector of names of variables corresponding to the
#'         conditions in rule (empty if rule is empty)
getNames <- function(rule){
  vapply(strsplit(rule, "_"),
         function(tokens) paste(tokens[tokens != "neg"], collapse="_"),
         character(1))
}


#' Get rule values
#'
#' Get the values of each condition in a rule: 0 for a negated condition (one
#' whose last "_"-separated token is "neg") and 1 otherwise.
#'
#' @param rule a rule, formatted as a vector of conditions
#' @return numeric vector of values (0 or 1) for each condition in rule
getValues <- function(rule){
  # Decide on the trailing "neg" token rather than on the number of tokens,
  # so "feature_neg" and names containing underscores are classified correctly
  is_neg <- vapply(strsplit(rule, "_"),
                   function(tokens) identical(tokens[length(tokens)], "neg"),
                   logical(1))
  as.numeric(!is_neg)
}


#' Simplify condition
#'
#' Simplify a negated condition if possible, based on supplied equivalence
#' classes, e.g., "not low education" = "high education". Conditions that are
#' not negated (do not end in "_neg") are returned unchanged.
#'
#' @param cond the condition to simplify, of the form var_val or var_val_neg
#' @param oppmat a matrix with two columns and K rows, where K is the length of
#'        the list oppind. The kth row contains values v1 and v2 (i.e.,
#'        v1=oppmat[k,1] and v2=oppmat[k,2]) such that for any variable var in
#'        oppind[[k]], var_v1 and !var_v2 are equivalent.
#' @param oppind a list of vectors of variables. Each vector oppind[[k]]
#'        contains variables var such that var_v1 and !var_v2 are equivalent,
#'        where v1 and v2 form the kth row of oppmat
#' @param feats vector of all possible features
#' @return a condition equivalent to cond. For a negated condition var_val_neg
#'         this is var_other, where other is the value paired with val in a row
#'         of oppmat that applies to var, provided var_other is in feats. If no
#'         such feature exists, cond itself is returned.
simplifyCondition <- function(cond, oppmat, oppind, feats){
  tokens <- strsplit(cond, "_")[[1]]
  n_tokens <- length(tokens)
  # Only negated conditions of the form var_..._val_neg can be simplified
  if (n_tokens < 2L || !identical(tokens[n_tokens], "neg")) {
    return(cond)
  }
  fg <- tokens[1]
  val <- tokens[n_tokens - 1L]
  # Keep only the equivalence rows whose variable list contains this variable
  applies <- which(vapply(oppind, function(vars) fg %in% vars, logical(1)))
  classes <- matrix(oppmat[applies, ], ncol=2)
  # Positions of val in the table; always a (possibly empty) two-column matrix
  hits <- which(classes == val, arr.ind = TRUE)
  for (i in seq_len(nrow(hits))) {
    other_col <- if (hits[i, 2] == 1) 2 else 1  # other column holds the opposite value
    feat <- paste(fg, classes[hits[i, 1], other_col], sep="_")
    if (feat %in% feats) {  # only return a feature that actually exists
      return(feat)
    }
  }
  cond  # no equivalent feature available
}


#' Get labels for features
#'
#' Get the labels of features to use in graphs
#'
#' @param feat feature or vector of features to get the label of; should be of
#'             the form "feature_value", optionally followed by "_neg"
#' @param labels_df dataframe of unique features and corresponding labels;
#'                  first column is the feature, second column is the
#'                  corresponding label
#' @param neg_label prefix to use for negative features
#' @return character vector with one label per element of feat. If a feature
#'         has no entry in labels_df, its label is just the prefix (empty for
#'         non-negated features).
getLabel <- function(feat, labels_df, neg_label){
  vapply(feat, function(f) {
    tokens <- strsplit(f, "_")[[1]]
    prefix <- ""
    if (identical(tokens[length(tokens)], "neg")) {
      prefix <- neg_label
      f <- paste(tokens[-length(tokens)], collapse="_")  # look up without "_neg"
    }
    # paste() yields the bare prefix when there is no match; [1] keeps the
    # first label should labels_df contain duplicates
    paste(prefix, labels_df[labels_df[,1]==f, 2], sep="")[1]
  }, character(1), USE.NAMES = FALSE)
}


#' Number of cases satisfying rule
#'
#' Get the number of cases that satisfy a rule
#'
#' @param rule rule to check: a vector of conditions, each in the format
#'             feature_value or feature, possibly with a "_neg" suffix. A
#'             condition without the suffix requires the column to equal 1, a
#'             condition with it requires the column (named without the
#'             suffix) to equal 0.
#' @param df dataframe or matrix of data (excluding outcome)
#' @param Y vector of outcomes. If supplied, the function will count the number
#'          of cases that both satisfy rule and have outcome equal to y_val
#' @param y_val value Y should take in order to be counted. Ignored if Y=NULL
#' @return number of cases in data that satisfy rule (and Y=y_val, if Y is not
#'         NULL). A rule without conditions is satisfied by every case.
numCases <- function(rule, df, Y=NULL, y_val=1){
  satisfied <- rep(TRUE, nrow(df))
  for (cond in rule) {
    tokens <- strsplit(cond, "_")[[1]]
    feat <- cond
    value <- 1
    if (identical(tokens[length(tokens)], "neg")) {
      value <- 0
      feat <- paste(tokens[-length(tokens)], collapse="_")  # column name w/o _neg
    }
    satisfied <- satisfied & (df[, feat] == value)
  }
  if (!is.null(Y)) {  # must also have the requested outcome
    satisfied <- satisfied & (Y == y_val)
  }
  sum(satisfied)
}


#' Discretize variable
#'
#' Creates a dummy variable for being below, or at/above, a specified quantile
#' of the variable itself (missing values are ignored when computing the
#' quantile).
#'
#' @param x vector to be discretized
#' @param q quantile (probability passed to \code{quantile})
#' @param leq if TRUE (default) the indicator is 1 for values strictly below
#'            the quantile; if FALSE it is 1 for values at or above it. Note
#'            that, despite the argument's name, the comparison is a strict
#'            "<" rather than "<=".
#' @return a numeric 0/1 indicator with one entry per element of x
discretize <- function(x, q, leq=TRUE) {
  cutoff <- quantile(x, q, na.rm=TRUE)
  below <- x < cutoff
  as.numeric(if (leq) below else !below)
}


#' Get label for a feature/variable
#'
#' Looks up the display label of a feature and, if the feature carries a
#' value, the display label of that value. The feature is the text before the
#' first "_" and the value is the token right after it; any further tokens are
#' ignored.
#'
#' @param feat the feature for which to get the label, either of the form
#'             "feature_value" or "feature"
#' @param feats a vector of features
#' @param flabels a vector of labels corresponding to feats (in the same order)
#' @param vals a vector of possible values that features can take on
#' @param vlabels a vector of labels corresponding to vals (in the same order)
#' @return the label corresponding to feat in the format
#'         "featureLabel (valueLabel)". If feat has no suffix (i.e., no value),
#'         then will return just the label for the feature, "featureLabel".
get_label_value <- function(feat, feats, flabels, vals, vlabels) {
  tokens <- strsplit(feat, "_")[[1]]
  feature_label <- flabels[which(feats==tokens[1])]
  if (length(tokens) <= 1) {
    return(feature_label)
  }
  value_label <- vlabels[which(vals==tokens[2])]
  paste0(feature_label, " (", value_label, ")")
}


#' Get proportion of true positives
#'
#' Get the 2.5th, 50th, and 97.5th quantiles of the proportion of true positives
#' (as a proportion of the total number of observations -- the positive portion
#' of the "coverage" statistic defined in the paper) over all bootstraps for
#' each rule in a group of rules all of the same length
#'
#' @param rules a dataframe or matrix of rules, all of the same length. Each row
#'              corresponds to a rule, and each element corresponds to a
#'              condition in the rule.
#' @param allIndices a list of indices of the bootstraps
#' @param reps the number of bootstrapped repetitions (the first reps entries
#'             of allIndices are used)
#' @param df a dataframe of the data (excluding outcome)
#' @param Y the outcome
#' @return a dataframe of quantiles of true positives with columns "min" (2.5th
#'         quantile), "median" and "max" (97.5th quantile). Each row
#'         corresponds to the rule in the same row of rules; there are no rows
#'         if rules has none.
.getTP <- function(rules, allIndices, reps, df, Y){
  n_rules <- nrow(rules)
  # min=.025 quantile, max=.975 quantile
  stats <- matrix(NA_real_, nrow=n_rules, ncol=3,
                  dimnames=list(NULL, c("min", "median", "max")))
  for (i in seq_len(n_rules)) { # Loop through rules
    rule <- unlist(rules[i,])
    # Number of true positives in each bootstrap sample; drop=FALSE keeps a
    # one-column data set two-dimensional for numCases
    coverage <- vapply(seq_len(reps), function(j) {
      ind <- allIndices[[j]]
      as.numeric(numCases(rule=rule, df=df[ind, , drop=FALSE], Y=Y[ind]))
    }, numeric(1))
    stats[i, ] <- c(quantile(coverage, .025), median(coverage), quantile(coverage, .975))/length(Y)
  }
  data.frame(stats)
}


#' Get proportion of false positives
#'
#' Get the 2.5th, 50th, and 97.5th quantiles of the proportion of false positives
#' (as a proportion of the total number of observations -- the negative portion
#' of the "coverage" statistic defined in the paper) over all bootstraps for
#' each rule in a group of rules all of the same length
#'
#' @param rules a dataframe or matrix of rules, all of the same length. Each row
#'              corresponds to a rule, and each element corresponds to a
#'              condition in the rule.
#' @param allIndices a list of indices of the bootstraps
#' @param reps the number of bootstrapped repetitions (the first reps entries
#'             of allIndices are used)
#' @param df a dataframe of the data (excluding outcome)
#' @param Y the outcome
#' @return a dataframe of quantiles of false positives with columns "min"
#'         (2.5th quantile), "median" and "max" (97.5th quantile). Each row
#'         corresponds to the rule in the same row of rules; there are no rows
#'         if rules has none.
.getFP <- function(rules, allIndices, reps, df, Y){
  n_rules <- nrow(rules)
  # min=.025 quantile, max=.975 quantile
  stats <- matrix(NA_real_, nrow=n_rules, ncol=3,
                  dimnames=list(NULL, c("min", "median", "max")))
  for (i in seq_len(n_rules)) { # Loop through rules
    rule <- unlist(rules[i,])
    # Number of false positives (rule satisfied, outcome 0) in each bootstrap
    # sample; drop=FALSE keeps a one-column data set two-dimensional
    coverage <- vapply(seq_len(reps), function(j) {
      ind <- allIndices[[j]]
      as.numeric(numCases(rule=rule, df=df[ind, , drop=FALSE], Y=Y[ind], y_val=0))
    }, numeric(1))
    stats[i, ] <- c(quantile(coverage, .025), median(coverage), quantile(coverage, .975))/length(Y)
  }
  data.frame(stats)
}


#' Make data frame for circlize
#'
#' Makes a data frame containing all the information needed to make a chord
#' diagram as described in the paper from a rule set. Note: the function was
#' originally written to allow for multiple rule sets to accommodate an earlier
#' version of the paper, where we propose constructing a different chord
#' diagram. This function can still be used to produce the new proposed chord
#' diagram, but will have unused outputs (frequency and degree of interaction).
#'
#' @param allRuleSets a list of all the rule sets produced using bootstrapping,
#'                    or a list containing a single rule set
#' @param featureGroups a data frame whose first column is feature names
#'                      (without values) as they appear in the rule sets and
#'                      whose second column is the corresponding labels that
#'                      will appear in the chord diagram
#' @param maxLen the maximum possible length of a rule
#' @param minProp the minimum proportion of times an interaction must appear to
#'                be included. Should be set to 0 if allRuleSets is a single
#'                rule set
#' @return data frame with columns "from" and "to" (the features to be
#'         connected, i.e. ones appearing in the same rule at least once; a
#'         single-condition rule links a feature to itself), "freq" (the
#'         number of times they appear together), "deg" (the degree of
#'         interaction, i.e. the length of the rules in which they appear
#'         together), and "color_ind" (an index that can then be mapped to a
#'         vector of specified colors). Has zero rows if no rule qualifies.
.get_df_chord <- function(allRuleSets, featureGroups, maxLen, minProp){
  # create separate lists of rules for each length; each rule is stored as the
  # sorted labels of its features joined by "__"
  allRules <- vector(mode="list", length=maxLen)
  for (ruleSet in allRuleSets) {
    for (rule in ruleSet) {
      allRules[[length(rule)]] <- c(allRules[[length(rule)]],
                                    paste(sort(sapply(rule, function(x)
                                      featureGroups[featureGroups[,1]==.getFeatGroup(x), 2][1])), collapse="__"))
    }
  }
  
  # tabulate interactions, keeping those that appear often enough
  counts <- vector(mode="list", length=maxLen)
  for (len in seq_len(maxLen)) {
    counts[[len]] <- table(as.character(allRules[[len]]))
    counts[[len]] <- counts[[len]][counts[[len]] >= minProp*length(allRuleSets)]
  }
  
  # make dfs: single-condition rules become self-links
  df <- c()
  if (maxLen >= 1) {
    for (i in seq_along(counts[[1]])) {
      count <- counts[[1]][i]
      df <- rbind.data.frame(df, c(names(count), names(count), count, 1, i))
    }
  }
  
  # interactions: link the first feature of each rule to every other feature.
  # seq_len(maxLen)[-1] is empty when maxLen is 1 (2:maxLen would count down).
  for (len in seq_len(maxLen)[-1]) {
    for (i in seq_along(counts[[len]])) {
      count <- counts[[len]][i]
      feats <- strsplit(names(count), "__")[[1]]
      for (j in seq_along(feats)[-1]) {
        # color indices continue after the largest index of the previous degree
        df <- rbind.data.frame(df, c(feats[1], feats[j], count, len, max(0, as.numeric(df[df[, 4]==(len-1), 5]))+i))
      }
    }
  }
  
  if (is.null(df)) {  # nothing qualified: return an empty, correctly named frame
    return(data.frame(from=character(0), to=character(0), freq=character(0),
                      deg=character(0), color_ind=character(0),
                      stringsAsFactors=FALSE))
  }
  
  colnames(df) <- c("from", "to", "freq", "deg", "color_ind")
  
  df
}


#' Get the feature corresponding to a rule
#'
#' Get the name of the feature corresponding to a rule as it appears in the
#' rule, i.e. the text before the first underscore. For example,
#' .getFeatGroup("feat_val") will return "feat"
#'
#' @param rule the rule whose feature we wish to extract
#' @return the name of the feature corresponding to rule
.getFeatGroup <- function(rule) {
  tokens <- strsplit(rule, "_", fixed = TRUE)[[1L]]
  tokens[1L]
}


#' Make a chord diagram
#'
#' Makes a chord diagram of a single rule set. Changes the graphical
#' parameters mar and cex via par() and does not restore them.
#'
#' @param ruleSet the rule set to plot, formatted as a list of rules
#' @param featureNames not used by this function
#' @param featureGroups the featureGroups input for .get_df_chord
#' @param linkColors a vector of colors for the links, indexed by the
#'                   color_ind column returned by .get_df_chord
#' @param gridColors the color of the arcs
#' @param bgLinkColor not used by this function
#' @param minProp the minProp input for .get_df_chord
#' @param textSize a graphical parameter for the cex of the text
#' @param line_arg not used by this function
#' @param side_mar a graphical parameter for adding white space to the sides of
#'                 the plot
#' @param top_mar a graphical parameter for adding white space to the top of
#'                 the plot
#' @return a chord diagram of the rule set
plot_chord <- function(ruleSet, featureNames, featureGroups,
                       linkColors, gridColors, bgLinkColor="gray",
                       minProp=0, textSize=1, line_arg=1, side_mar=0, top_mar=0){
  
  # The longest rule determines how many interaction degrees are tabulated
  maxLen <- max(lengths(ruleSet))
  
  chord_df <- .get_df_chord(list(ruleSet), featureGroups, maxLen=maxLen, minProp=minProp)

  # Reset circlize and set up the plotting device
  circlize::circos.clear()
  circlize::circos.par(gap.after = 10)
  par(mar = c(0, side_mar, top_mar, side_mar), cex=textSize)
  
  # Look up the color of every link, one interaction degree at a time
  link_cols <- rep(0, times=nrow(chord_df))
  for (deg in 1:maxLen) {
    rows <- which(chord_df[, "deg"]==deg)
    if (length(rows) == 0) {
      next
    }
    link_cols[rows] <- linkColors[as.numeric(chord_df[rows, "color_ind"])]
  }
  
  circlize::chordDiagram(chord_df[, 1:3],
                         link.sort=TRUE,
                         grid.col = gridColors, col = link_cols,
                         annotationTrack = c("name", "grid"),
                         annotationTrackHeight = c(0.03, 0.05),
                         self.link = 1)
  circlize::circos.clear()
}




#' Make a t-SNE plot
#' 
#' Makes a t-SNE plot of the data and using a rule set. Color codes based on
#' actual outcome and symbol codes based on whether the classification by the
#' rule set agrees with the actual outcome. Changes graphical parameters via
#' par() and does not restore them.
#' 
#' @param df data, excluding outcome
#' @param Y outcome
#' @param A rule set
#' @param caseColors a vector of colors, the first for Y=1 the second  for Y=0
#' @param symb a numeric vector to determine symbol type (pch for plot), the 
#'             first for correctly classified cases (Y=Yhat) and the second
#'             for misclassified cases
#' @param pointSize graphical parameter for size of points (cex for plot)
#' @param textSize graphical parameter for size of text (cex for legend)
#' @param bottom_buffer graphical parameter for adding white space to the 
#'                      bottom of the plot to make room for legend
#' @param all_buffer graphical parameter for adding white space around plot
#' @param legend_under_plot logical for whether the legend should be under the
#'                          plot. If false, the legend will be inside the plot
#' @param legend_bg_col the background color of the legend
#' @param legend_offset a vector of how much to offset the legend along each
#'                      axis
#' @param legend_position position keyword for the legend when it is drawn
#'                        inside the plot
#' @param jitter_factor the factor input for the jitter function
#' @param jitter_amount the amount input for the jitter function
#' @param max_iter the maximum iteration to run tsne for
#' @param highlight the index (or indices) of the rule in A to be highlighted.
#'                  The resulting graph will highlight A[highlight]
#' @param box_color fill color of the boxes drawn behind highlighted cases;
#'                  "grey" if NULL
#' @param shade_name legend label for the highlighted cases; only used when the
#'                   legend is drawn inside the plot
#' @return a t-SNE plot 
plot_tsne <- function(df, Y, A, caseColors, symb=c(1, 4),
                           pointSize=1, textSize=1, 
                           bottom_buffer=1.25, all_buffer=1,
                           legend_under_plot=TRUE, legend_bg_col="transparent", 
                           legend_offset=c(0,0), legend_position="bottomright", 
                           jitter_factor=1, jitter_amount=NULL,
                           max_iter=1000,
                           highlight=NULL, box_color=NULL,
                           shade_name="Highlighted"){
  
  # Fill color of the highlight boxes (previously an undefined `boxColor`)
  box_col <- if (is.null(box_color)) "grey" else box_color
  
  # Split a condition into the data column it refers to and the value it
  # requires (0 for a "_neg" condition, 1 otherwise)
  parse_cond <- function(cond) {
    tokens <- strsplit(cond, "_")[[1]]
    is_neg <- as.numeric(tokens[length(tokens)]=="neg")
    list(feature=paste(tokens[1:(length(tokens)-is_neg)], collapse="_"),
         value=as.numeric(!is_neg))
  }
  
  ## Include only outcome and features in A
  incl <- c() # variables to include
  for (rule in A) {
    for (cond in rule) {
      incl <- c(incl, parse_cond(cond)$feature)
    }
  }
  
  ## Run t-SNE
  train <- cbind.data.frame(Y, df[, colnames(df)[which(colnames(df) %in% incl)]])
  tsne <- Rtsne::Rtsne(train, dims=2, perplexity=5, verbose=FALSE, max_iter=max_iter, check_duplicates=FALSE)
  
  Yhat <- getYhat(df, A) # Classification based on A
  
  ## Jitter
  tsne$Y <- jitter(tsne$Y, factor=jitter_factor, amount=jitter_amount)
  ## Standardize coordinates
  rangeX <- max(tsne$Y[,1])-min(tsne$Y[,1])
  rangeY <- max(tsne$Y[,2])-min(tsne$Y[,2])
  tsne$Y[,1] <- tsne$Y[,1]/rangeX*100
  tsne$Y[,2] <- tsne$Y[,2]/rangeY*100
  
  
  ## Highlight; plot before points otherwise will cover them up 
  boxWidth <- (max(tsne$Y[,1]) - min(tsne$Y[,1]))/25
  boxHeight <- (max(tsne$Y[,2]) - min(tsne$Y[,2]))/25
  par(mar=c(bottom_buffer,0,0,0)+all_buffer)
  # Sets up the axes; the single point plotted lies at twice the largest coordinate
  plot(2*max(tsne$Y), xlim=c(min(tsne$Y[,1]), max(tsne$Y[,1])), ylim=c(min(tsne$Y[,2]), max(tsne$Y[,2])),
       xlab="", ylab="",
       xaxt='n', yaxt='n')
  for (a in A[highlight]) {  # For each rule we want to plot
    # Cases that satisfy every condition of this rule (works for one or many)
    satisfied <- rep(TRUE, nrow(df))
    for (cond in a) {
      parsed <- parse_cond(cond)
      satisfied <- satisfied & (df[, parsed$feature]==parsed$value)
    }
    # If satisfies rule, then classified as positive (if doesn't, not recoded)
    Yhat[Yhat==0] <- as.numeric(satisfied)[Yhat==0]
    
    # tsne coordinates of cases that satisfy the rule; drop=FALSE keeps a
    # matrix even for zero or one matching case
    tsne_pos <- tsne$Y[which(satisfied), , drop=FALSE]
    if (nrow(tsne_pos) > 0) {
      # Draw a box around each of them
      rect(xleft=tsne_pos[,1]-boxWidth/2, ybottom=tsne_pos[,2]-boxHeight/2,
           xright=tsne_pos[,1]+boxWidth/2, ytop=tsne_pos[,2]+boxHeight/2,
           border=NA,
           lwd=10,
           col=box_col)
    }
  }
  
  ## Plot points
  par(new=TRUE, mar=c(bottom_buffer,0,0,0)+all_buffer)
  plot(tsne$Y,
       col=ifelse(Y==1, yes=caseColors[1], no=caseColors[2]),  # color code by outcome
       pch=ifelse(Y==Yhat, yes=symb[1], no=symb[2]),  # symbol code by correct classification 
       lwd=ifelse(Y==Yhat, yes=0, no=1),
       cex = pointSize,
       xlim=c(min(tsne$Y[,1]), max(tsne$Y[,1])), ylim=c(min(tsne$Y[,2]), max(tsne$Y[,2])),
       xaxt='n', yaxt='n', xlab="", ylab="")
  
  ## Legend
  if(legend_under_plot){
    legend(x=min(tsne$Y[,1])+legend_offset[1], y=min(tsne$Y[,2])-1+legend_offset[2], ncol=2,
           legend=c("True positive", "True negative", "False negative", "False positive"),
           bg=legend_bg_col, col=c(caseColors, caseColors),
           pch=c(symb[1], symb[1], symb[2], symb[2]), box.lty = 0, cex=textSize,
           xpd=TRUE)
  } else { ## legend inside plot
    # NOTE(review): entries are kept as originally written; pch lists more
    # symbols than there are labels -- confirm the intended legend content
    legend(legend_position, inset=legend_offset,
           legend=c(shade_name, "Y=1", "Y=0"),
           bg=legend_bg_col, col=c(box_col, caseColors),
           pch=c(15, rep(symb, 2)), box.lty = 0, cex=textSize)
  }
}


